import torch
import transducer_loss_cuda
import transducer_joint_cuda

class TransducerJoint(torch.nn.Module):
    """Transducer joint
    Combines the transcription vector f and the prediction vector g (see forward) through the
    fused transducer_joint_cuda kernel. Details of the transducer model can be found in:
    Sequence Transduction with Recurrent Neural Networks

    Arguments:
        pack_output (bool, optional): whether to pack the output in a compact form with don't-care
            data being removed. (default: False)
        relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1.
            (default: False)
        dropout (bool, optional): apply dropout to the output of the joint operation. Dropout is
            only applied while the module is in training mode. Requires opt=1. (default: False)
        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm.
            (default: 1)
        fwd_tile_size (int, optional): tile size used in forward operation. This argument will be
            ignored if opt != 1. (default: 4)
        dropout_prob (float, optional): dropout probability. (default: 0)
        probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout
            operation. When this argument is set to True and relu or dropout is enabled,
            self.mask_probe is a list to which the mask of every forward call that applies one is
            appended (the list is never cleared here; callers must empty it). Otherwise
            self.mask_probe is None. (default: False)

    Raises:
        NotImplementedError: if relu or dropout is requested together with opt != 1.
    """

    def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4,
                 dropout_prob=0, probe_mask=False):
        super(TransducerJoint, self).__init__()
        self.pack_output = pack_output
        self.relu = relu
        self.dropout = dropout
        self.dropout_prob = dropout_prob
        self.opt = opt
        self.fwd_tile_size = fwd_tile_size
        # Empty placeholder handed to the kernel instead of batch_offset when packing is disabled.
        self.dummy_batch_offset = torch.empty(0)
        masked = self.relu or self.dropout
        # Only allocate a probe list when there is a ReLU/dropout mask to probe.
        self.mask_probe = [] if masked and probe_mask else None
        if masked and opt != 1:
            raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1")

    def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0):
        """Forward operation of transducer joint

        Arguments:
            f (tensor): transcription vector from encode block of shape (B, T, H).
            g (tensor): prediction vector form predict block of shape (B, U, H).
            f_len (tensor): length of transcription vector for each batch.
            g_len (tensor): length of prediction vector minus 1 for each batch.
            batch_offset (tensor, optional): tensor containing the offset of each batch
                in the results. For example, batch offset can be obtained from:
                batch_offset = torch.cumsum(f_len*g_len, dim=0)
                This argument is required if pack_output == True, and is ignored if
                pack_output == False. (default: None)
            packed_batch (int, optional): the batch size after packing. This argument is
                required (non-zero) if pack_output == True, and is ignored if
                pack_output == False. (default: 0)

        Returns:
            The joint tensor returned by TransducerJointFunc (presumably (B, T, U, H) when
            pack_output == False -- TODO confirm against the kernel).

        Raises:
            Exception: if pack_output is enabled but batch_offset or packed_batch is missing.
        """
        # Validate before picking the offset so misuse is reported up front.
        if self.pack_output and (batch_offset is None or packed_batch == 0):
            raise Exception("Please specify batch_offset and packed_batch when packing is enabled")
        my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset
        # Dropout is only active in training mode; in eval mode only ReLU (if any) is masked.
        dropout = self.dropout and self.training
        return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout,
                                         my_batch_offset, packed_batch, self.opt,
                                         self.fwd_tile_size, self.dropout_prob, self.mask_probe)


class TransducerLoss(torch.nn.Module):
    """Transducer loss
    Details of this loss function can be found in: Sequence Transduction with Recurrent Neural
    Networks

    Arguments:
        fuse_softmax_backward (bool, optional): whether to fuse the backward of transducer loss with
            softmax. (default: True)
        opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized
            algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1)
        packed_input (bool, optional): whether the input x is packed in a compact form with
            don't-care data removed. (default: False)
    """
    def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False):
        super(TransducerLoss, self).__init__()
        self.fuse_softmax_backward = fuse_softmax_backward
        self.opt = opt
        self.packed_input = packed_input
        # Empty placeholder handed to the kernel instead of batch_offset when input is not packed.
        self.dummy_batch_offset = torch.empty(0)

    def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None,
                debug_list=None):
        """Forward operation of transducer loss

        Arguments:
            x (tensor): input tensor to the loss function with a shape of (B, T, U, H) when
                packed_input == False.
            label (tensor): labels for the input data.
            f_len (tensor): lengths of the inputs in the time dimension for each batch.
            y_len (tensor): lengths of the labels for each batch.
            blank_idx (int): index for the null symbol.
            batch_offset (tensor, optional): tensor containing the offset of each batch
                in the input. For example, batch offset can be obtained from:
                batch_offset = torch.cumsum(f_len*(y_len+1), dim=0)
                This argument is required if packed_input == True, and is ignored if
                packed_input == False. (default: None)
            max_f_len (int, optional): maximum length of the input in the time dimension.
                For example, it can be obtained as
                max_f_len = max(f_len)
                This argument is required if packed_input == True, and is ignored if
                packed_input == False, where x.size(1) is used instead. (default: None)
            debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated
                in the forward operation will be attached to this list for debug purpose.
                (default: None)

        Returns:
            The loss tensor returned by TransducerLossFunc.

        Raises:
            Exception: if packed_input is enabled but batch_offset or max_f_len is missing.
        """
        if self.packed_input:
            if batch_offset is None or max_f_len is None:
                # Implicit concatenation keeps the message free of the source indentation that a
                # backslash continuation inside the literal would embed.
                raise Exception("Please specify batch_offset and max_f_len when packing is "
                                "enabled")
            my_batch_offset = batch_offset
            my_max_f_len = max_f_len
        else:
            my_batch_offset = self.dummy_batch_offset
            # Unpacked input: the time dimension is dim 1 of x.
            my_max_f_len = x.size(1)
        return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len,
                                        blank_idx, self.fuse_softmax_backward, debug_list,
                                        self.opt, self.packed_input)

class TransducerLossFunc(torch.autograd.Function):
    """Autograd function backing TransducerLoss.

    forward applies log_softmax over the last dimension of x and passes the log-probabilities to
    transducer_loss_cuda.forward. backward calls transducer_loss_cuda.backward; when the softmax
    backward is not fused into that kernel, the log_softmax backward is applied here.
    """

    @staticmethod
    def _log_softmax_backward(grad, log_probs):
        """Chain ``grad`` (w.r.t. ``log_probs = log_softmax(logits, dim=-1)``) back to the logits.

        Uses d(logits) = grad - softmax(logits) * sum(grad, dim=-1), with
        softmax(logits) = exp(log_probs). The row sum is accumulated in at least float32 and the
        result is cast back to ``grad.dtype``.
        """
        acc_dtype = torch.promote_types(grad.dtype, torch.float32)
        grad_sum = grad.sum(dim=-1, keepdim=True, dtype=acc_dtype)
        return (grad - torch.exp(log_probs) * grad_sum).to(grad.dtype)

    @staticmethod
    def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx,
                fuse_softmax_backward, debug_list, opt, packed_input):
        # The log-probabilities are what the CUDA kernels consume. They are also all the unfused
        # backward needs, since softmax(x) == exp(log_softmax(x)); no autograd graph is kept.
        x = torch.nn.functional.log_softmax(x, dim=-1)
        alpha, beta, loss = transducer_loss_cuda.forward(x, label, f_len, y_len, batch_offset,
                                                         max_f_len, blank_idx, opt, packed_input)
        # Only an explicitly supplied empty list receives the intermediate alpha/beta tensors.
        if debug_list == []:
            debug_list += [alpha, beta]
        ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset)
        ctx.blank_idx = blank_idx
        ctx.fuse_softmax_backward = fuse_softmax_backward
        ctx.opt = opt
        ctx.packed_input = packed_input
        ctx.max_f_len = max_f_len
        return loss

    @staticmethod
    def backward(ctx, loss_grad):
        x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors
        x_grad = transducer_loss_cuda.backward(x, loss_grad, alpha, beta, f_len, y_len, label,
                                               batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt,
                                               ctx.fuse_softmax_backward, ctx.packed_input)
        if not ctx.fuse_softmax_backward:
            # The kernel returned d(loss)/d(log_probs); chain it through log_softmax explicitly.
            # The previous ``x.backward(x_grad)`` returned None to autograd. It pushed gradients
            # into leaves through a re-entrant backward that also freed the upstream graph.
            x_grad = TransducerLossFunc._log_softmax_backward(x_grad, x)
        # One gradient per forward input (11); only x is differentiable.
        return x_grad, None, None, None, None, None, None, None, None, None, None

class TransducerJointFunc(torch.autograd.Function):
    """Autograd function backing TransducerJoint; wraps transducer_joint_cuda forward/backward."""

    # Number of arguments forward() takes after ctx; backward must return one gradient for each.
    _NUM_FORWARD_INPUTS = 13

    @staticmethod
    def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch,
                opt, fwd_tile_size, dropout_prob, mask_probe):
        h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt,
                                          pack_output, relu, dropout, dropout_prob, fwd_tile_size)
        masked = relu or dropout
        if masked:
            # h[1] is the ReLU/dropout mask; backward hands it to the kernel with the gradient.
            ctx.save_for_backward(h[1], f_len, g_len, batch_offset)
            if mask_probe is not None:
                mask_probe.append(h[1])
        else:
            ctx.save_for_backward(f_len, g_len, batch_offset)

        ctx.pack_output = pack_output
        ctx.masked = masked
        ctx.max_f_len = f.size(1)
        ctx.max_g_len = g.size(1)
        # Scale 1/(1-p) forwarded to the backward kernel when dropout is active; p == 1 is
        # special-cased to avoid a division by zero.
        ctx.scale = 1 / (1 - dropout_prob) if dropout and dropout_prob != 1 else 1
        return h[0]

    @staticmethod
    def backward(ctx, loss_grad):
        if ctx.masked:
            mask, f_len, g_len, batch_offset = ctx.saved_tensors
            inp = [loss_grad, mask]
        else:
            f_len, g_len, batch_offset = ctx.saved_tensors
            inp = [loss_grad]

        f_grad, g_grad = transducer_joint_cuda.backward(inp, f_len, g_len, batch_offset,
                                                        ctx.max_f_len, ctx.max_g_len,
                                                        ctx.pack_output, ctx.scale)

        # Exactly one gradient per forward input: f and g are differentiable, the rest are not.
        return (f_grad, g_grad) + (None,) * (TransducerJointFunc._NUM_FORWARD_INPUTS - 2)


